#!/usr/bin/env python

import sys
import argparse
import os

class colors:
    """ANSI terminal escape sequences used to highlight this script's output."""

    # Resets the terminal back to its default text attributes.
    NC = '\033[0m'
    # Switches subsequent text to green.
    GREEN = '\033[0;32m'

# Help text shown by -h/--help. The backslash continuations keep this a single
# string literal (the indentation of the continued lines is part of the string).
script_description = 'This script takes NCBI protein accessions and returns a two-column \
                                              tab-delimited file with protein accessions and taxids. It requires the \
                                              "prot.accession2taxid" database (unzipped) that can be downloaded from here: \
                                              ftp://ftp.ncbi.nih.gov/pub/taxonomy/accession2taxid/prot.accession2taxid.gz\
                                              For version info, run `bit-version`.'

parser = argparse.ArgumentParser(description=script_description)

# Mandatory flags get their own help section titled "required arguments".
required = parser.add_argument_group('required arguments')

required.add_argument(
    "-r", "--ref_map",
    dest="input_ref",
    action="store",
    required=True,
    help="prot_acc_to_taxid_map",
)
required.add_argument(
    "-w", "--wanted_prot_accessions",
    dest="prot_accs",
    action="store",
    required=True,
    help="Single-column file with protein accessions",
)
parser.add_argument(
    "-o", "--output_file",
    dest="file_out",
    action="store",
    default="Wanted_prot_accs_and_taxids.tsv",
    help='Output file of prot_acc and taxID',
)

# Called with no arguments at all: print the full help to stderr and exit
# successfully instead of letting argparse report the missing required flags.
if len(sys.argv) == 1:
    parser.print_help(sys.stderr)
    sys.exit(0)

args = parser.parse_args()

def map_accessions_to_taxids(ref_path, wanted_path, out_path):
    """Write a two-column accession/taxid table for the wanted accessions.

    The reference map is streamed line by line; column 2 (index 1) of each
    tab-separated row is compared against the wanted accessions and column 3
    (index 2) is written out as the taxid. Wanted accessions that never
    appear in the reference map are appended with a taxid of "NA", in the
    order they were first listed in the wanted file.

    Parameters:
        ref_path: path to the tab-delimited reference map
            (e.g. the unzipped "prot.accession2taxid" file).
        wanted_path: path to a single-column file of protein accessions.
            Blank lines and repeated accessions are ignored.
        out_path: path of the tab-delimited file to create (overwritten).

    Returns:
        A tuple (num_wanted, num_found): the number of distinct accessions
        requested and the number of those present in the reference map.

    Raises:
        IOError/OSError if any of the files cannot be opened.

    Note: the progress line (emitted every million reference rows) uses the
    module-level `colors` class.
    """

    # Keep both a list (stable, first-seen order for the "NA" rows) and a set
    # (constant-time membership tests while scanning the large reference map).
    wanted_in_order = []
    wanted_lookup = set()

    with open(wanted_path, "r") as wanted_handle:
        for raw_line in wanted_handle:
            accession = raw_line.strip()
            # Skip blank lines so an empty string is never treated as an accession.
            if accession and accession not in wanted_lookup:
                wanted_lookup.add(accession)
                wanted_in_order.append(accession)

    # A set, not a list: the "not found" pass below would otherwise be
    # quadratic in the number of accessions.
    found = set()

    with open(out_path, "w") as out_handle:

        out_handle.write("prot_accession\ttaxid\n")

        print("\nNow beginning trek through reference mapping file.\n")

        with open(ref_path, "r") as ref_handle:

            for line_num, raw_line in enumerate(ref_handle, 1):

                if line_num % 1000000 == 0:
                    # Floor division so the counter prints as "3", not "3.0", on Python 3.
                    millions = line_num // 1000000
                    sys.stdout.write("\r  On line " + colors.GREEN + str(millions) + colors.NC + " million of prot_acc_to_taxid_map...")
                    sys.stdout.flush()

                fields = raw_line.rstrip("\r\n").split("\t")

                # Blank or truncated rows cannot hold an accession and a taxid.
                if len(fields) < 3:
                    continue

                if fields[1] in wanted_lookup:
                    out_handle.write(fields[1] + "\t" + fields[2] + "\n")
                    found.add(fields[1])

        print("\n")

        print('  Adding in "NA"s for those protein accessions not found...\n')

        for accession in wanted_in_order:
            if accession not in found:
                out_handle.write(accession + "\tNA\n")

    return len(wanted_in_order), len(found)


if __name__ == "__main__":

    num_wanted, num_found = map_accessions_to_taxids(args.input_ref, args.prot_accs, args.file_out)

    print(colors.GREEN + "  Done!" + colors.NC + "\n")
    print("  You were looking for " + str(num_wanted) + " protein accessions.")
    print("  " + str(num_found) + ' were found. The rest were given taxids of "NA".')
